{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Perturbed-Attention Guidance\n",
    "\n",
    "[Project](https://cvlab-kaist.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/cvlab-kaist/Perturbed-Attention-Guidance)\n",
    "\n",
    "This implementation is based on [Diffusers](https://huggingface.co/docs/diffusers/index). StableDiffusionPAGPipeline is a modification of StableDiffusionPipeline to support Perturbed-Attention Guidance (PAG). This script was contributed by [Hyoungwon Cho](https://github.com/HyoungwonCho) and the notebook by [Parag Ekbote](https://github.com/ParagEkbote).\n",
    "\n",
    "PAG Parameters-\n",
    "\n",
    "`pag_scale` : guidance scale of PAG (ex: 5.0)\n",
    "\n",
    "`pag_applied_layers_index` : index of the layer to apply perturbation (ex: ['m0'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
      "Requirement already satisfied: diffusers in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (0.31.0)\n",
      "Requirement already satisfied: torch in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (2.6.0)\n",
      "Requirement already satisfied: importlib-metadata in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (8.5.0)\n",
      "Requirement already satisfied: filelock in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (3.16.1)\n",
      "Requirement already satisfied: huggingface-hub>=0.23.2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (0.26.3)\n",
      "Requirement already satisfied: numpy in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (1.26.4)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (2024.11.6)\n",
      "Requirement already satisfied: requests in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (2.32.3)\n",
      "Requirement already satisfied: safetensors>=0.3.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (0.4.5)\n",
      "Requirement already satisfied: Pillow in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (9.0.0)\n",
      "Requirement already satisfied: typing-extensions>=4.10.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (4.12.2)\n",
      "Requirement already satisfied: networkx in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (3.4.2)\n",
      "Requirement already satisfied: jinja2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (3.1.4)\n",
      "Requirement already satisfied: fsspec in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (2024.9.0)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.4.127)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.4.127)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.4.127)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (9.1.0.70)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.4.5.8)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (11.2.1.3)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (10.3.5.147)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (11.6.1.9)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.3.1.170)\n",
      "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (0.6.2)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (2.21.5)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.4.127)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.4.127)\n",
      "Requirement already satisfied: triton==3.2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (3.2.0)\n",
      "Requirement already satisfied: sympy==1.13.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (1.13.1)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from sympy==1.13.1->torch) (1.3.0)\n",
      "Requirement already satisfied: packaging>=20.9 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub>=0.23.2->diffusers) (24.1)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub>=0.23.2->diffusers) (6.0.2)\n",
      "Requirement already satisfied: tqdm>=4.42.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub>=0.23.2->diffusers) (4.66.6)\n",
      "Requirement already satisfied: zipp>=3.20 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from importlib-metadata->diffusers) (3.21.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from jinja2->torch) (3.0.2)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (3.4.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (3.10)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (2.2.3)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (2024.8.30)\n",
      "\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.0.1\u001b[0m\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "pip install diffusers torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a7c1b4a389334744bc910405fa88d0d7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "pipeline.py:   0%|          | 0.00/72.5k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "69d09ab9990c4c1c87364ce2384f7eac",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py:2917: FutureWarning: `LoraLoaderMixin` is deprecated and will be removed in version 1.0.0. LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead.\n",
      "  deprecate(\"LoraLoaderMixin\", \"1.0.0\", deprecation_message)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "baf8796ee51c4f74a595d6fd8d0b067a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "98004aef328f4e4cbc26fa066a09ed9c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os\n",
    "import torch\n",
    "\n",
    "from accelerate.utils import set_seed\n",
    "\n",
    "from diffusers import StableDiffusionPipeline\n",
    "from diffusers.utils import load_image, make_image_grid\n",
    "from diffusers.utils.torch_utils import randn_tensor\n",
    "\n",
    "pipe = StableDiffusionPipeline.from_pretrained(\n",
    "    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n",
    "    custom_pipeline=\"hyoungwoncho/sd_perturbed_attention_guidance\",\n",
    "    torch_dtype=torch.float16\n",
    ")\n",
    "\n",
    "device = \"cuda\"\n",
    "pipe = pipe.to(device)\n",
    "\n",
    "pag_scale = 5.0\n",
    "pag_applied_layers_index = ['m0']\n",
    "\n",
    "batch_size = 4\n",
    "seed = 10\n",
    "\n",
    "base_dir = \"./results/\"\n",
    "grid_dir = base_dir + \"/pag\" + str(pag_scale) + \"/\"\n",
    "\n",
    "if not os.path.exists(grid_dir):\n",
    "    os.makedirs(grid_dir)\n",
    "\n",
    "set_seed(seed)\n",
    "\n",
    "latent_input = randn_tensor(shape=(batch_size,4,64,64), generator=None, device=device, dtype=torch.float16)\n",
    "\n",
    "output_baseline = pipe(\n",
    "    \"\",\n",
    "    width=512,\n",
    "    height=512,\n",
    "    num_inference_steps=50,\n",
    "    guidance_scale=0.0,\n",
    "    pag_scale=0.0,\n",
    "    pag_applied_layers_index=pag_applied_layers_index,\n",
    "    num_images_per_prompt=batch_size,\n",
    "    latents=latent_input\n",
    ").images\n",
    "\n",
    "output_pag = pipe(\n",
    "    \"\",\n",
    "    width=512,\n",
    "    height=512,\n",
    "    num_inference_steps=50,\n",
    "    guidance_scale=0.0,\n",
    "    pag_scale=5.0,\n",
    "    pag_applied_layers_index=pag_applied_layers_index,\n",
    "    num_images_per_prompt=batch_size,\n",
    "    latents=latent_input\n",
    ").images\n",
    "\n",
    "grid_image = make_image_grid(output_baseline + output_pag, rows=2, cols=batch_size)\n",
    "grid_image.save(grid_dir + \"sample.png\")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
